-
Notifications
You must be signed in to change notification settings - Fork 5.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[RLlib] Add gradient checks to avoid nan
gradients in TorchLearner
.
#47452
[RLlib] Add gradient checks to avoid nan
gradients in TorchLearner
.
#47452
Conversation
Signed-off-by: simonsays1980 <[email protected]>
…in highly unstable training phases. This helps to keep the optimizer's internal state intact whoch could get corrupted with many zero gradients. Furthermore, added better logging messages. Signed-off-by: simonsays1980 <[email protected]>
Signed-off-by: simonsays1980 <[email protected]>
nan
gradients in TorchLearner
.
@@ -176,7 +176,27 @@ def compute_gradients( | |||
def apply_gradients(self, gradients_dict: ParamDict) -> None: | |||
# Set the gradient of the parameters. | |||
for pid, grad in gradients_dict.items(): | |||
self._params[pid].grad = grad | |||
# If updates should not be skipped turn `nan` gradients to zero. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wait, I'm confused. We have this block here further below, which I think does the exact same thing: skips the entire optim.step
in case any gradient is non-finite (inf or nan).
# `step` the optimizer (default), but only if all gradients are finite.
elif all(
param.grad is None or torch.isfinite(param.grad).all()
for group in optim.param_groups
for param in group["params"]
):
Can you check and see whether these two logics can be consolidated?
Kind of like this:
- If user sets this flag (default=False), the optimizer will skip the update step entirely (+ warning raised by RLlib).
- If user does NOT set this flag (default behavior), grads that are non-finite will be set to 0.0 (+ warning raised by RLlib).
nan
gradients in TorchLearner
.nan
gradients in TorchLearner
.
…lution considers non-finite gradients and gives the user still the option to set such gradients to zero to keep the optimizer's internal state intact. Signed-off-by: simonsays1980 <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM! Thanks @simonsays1980
…`. (ray-project#47452) Signed-off-by: ujjawal-khare <[email protected]>
…`. (ray-project#47452) Signed-off-by: ujjawal-khare <[email protected]>
…`. (ray-project#47452) Signed-off-by: ujjawal-khare <[email protected]>
…`. (ray-project#47452) Signed-off-by: ujjawal-khare <[email protected]>
…`. (ray-project#47452) Signed-off-by: ujjawal-khare <[email protected]>
…`. (ray-project#47452) Signed-off-by: ujjawal-khare <[email protected]>
…`. (ray-project#47452) Signed-off-by: ujjawal-khare <[email protected]>
Why are these changes needed?
If any gradients turn
nan
inTorchLearner
these gradients get added to the network's weights and in turn weights becomenan
and all network outputs as well. As a result the training errors out and stops. This PR proposes a gradient check to only add gradients if they are sane. It switchesnan
values to zeros in the gradients or skips en update entirely.The latter can be of advantage, if training phase ecnounters highly unstable policy updates (e.g. with highly explorative policies or during early stages of training). In such phases many gradients could turn
nan
and this may lead to corrupted internal optimizer states (e.g. Adam).Related issue number
#47451
Checks
git commit -s
) in this PR.scripts/format.sh
to lint the changes in this PR.method in Tune, I've added it in
doc/source/tune/api/
under thecorresponding
.rst
file.